
Conversation

@guyueh1 guyueh1 commented Feb 5, 2026

What does this PR do?

In the nemo_gym environment, retry the rollout when generation_logprobs contains NaN.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this
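A minimal illustrative sketch (all values below are placeholders; the walkthrough describes the field on NemoGymConfig, and the review thread later moves the knob to env.nemo_gym.rollout_max_attempts_to_avoid_lp_nan, so confirm the final key name):

# Illustrative only: wiring the retry cap into the NemoGym environment config.
nemo_gym_config = {
    "model_name": "<generation model name>",               # placeholder
    "base_urls": ["<openai-compatible server base url>"],  # placeholder
    "initial_global_config_dict": {},                      # the env.nemo_gym section of the run config
    "rollout_max_retries_to_avoid_lp_nan": 3,              # maximum attempts when NaN logprobs are detected
}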

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features

    • Improved NeMo-Gym rollout robustness: rollout operations now automatically retry when NaN values are detected in generation log probabilities. A new configurable parameter controls the maximum number of retry attempts.
  • Tests

    • Added test coverage for the rollout retry mechanism.

@guyueh1 guyueh1 requested review from a team as code owners February 5, 2026 17:01

coderabbitai bot commented Feb 5, 2026

📝 Walkthrough

The changes introduce retry logic to NeMo-Gym rollout collection to handle NaN values in generation log probabilities. A new configuration field enables retrying rollouts up to a maximum count, with NaN detection in generation_logprobs triggering automatic retries until valid results are obtained or the retry limit is reached.
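
A schematic sketch of this retry pattern (illustrative only, not the actual implementation; run_rollout_once and the flat generation_logprobs list are hypothetical stand-ins for the real rollout call and result layout):

import math

def collect_rollouts(cfg, run_rollout_once):
    # Retry the full rollout while NaNs show up in the generation logprobs,
    # up to the configured number of attempts.
    max_retries = cfg["rollout_max_retries_to_avoid_lp_nan"]
    assert max_retries >= 1, "at least one rollout attempt is required"
    trial = 0
    result = None
    while trial < max_retries:
        result = run_rollout_once()
        if not any(math.isnan(lp) for lp in result["generation_logprobs"]):
            break  # valid logprobs; keep this result
        trial += 1  # NaN detected; retry
    return result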

Changes

  • Configuration Update (examples/nemo_gym/run_grpo_nemo_gym.py): Added the rollout_max_retries_to_avoid_lp_nan keyword argument to NemoGymConfig initialization, sourced from configuration with a default value of 1.
  • Core Retry Logic (nemo_rl/environments/nemo_gym.py): Added the rollout_max_retries_to_avoid_lp_nan configuration field to NemoGymConfig and reworked the run_rollouts method to implement a while loop that retries rollouts when NaN values are detected in generation_logprobs, incrementing a trial counter and re-executing until valid results are obtained or max retries are reached.
  • Test Coverage (tests/unit/environments/test_nemo_gym.py): Introduced the module-level constant NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN and an invocation counter, added a nemo_gym_with_patched_run_examples fixture that patches RolloutCollectionHelper.run_examples to inject NaN values and track calls, and added the test_nemo_gym_rollout_max_retries_to_avoid_lp_nan test to verify retry behavior matches the configuration.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed
❌ Failed checks (2 warnings)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 44.44%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Test Results For Major Changes: ⚠️ Warning. The PR introduces a major feature with core rollout logic changes but lacks documented test results or testing information in the description. Resolution: add test execution results to the PR description demonstrating all tests pass, and include regression testing documentation, particularly for convergence and numerical stability impacts.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped: CodeRabbit's high-level summary is enabled.
  • Title check: ✅ Passed. The title accurately captures the primary change: adding retry logic when generation_logprobs contains NaN, which is implemented across all modified files.



@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/nemo_gym/run_grpo_nemo_gym.py (1)

252-257: ⚠️ Potential issue | 🟠 Major

Avoid code-level default for rollout_max_retries_to_avoid_lp_nan.

Using cfg.get(..., 1) introduces a hidden default in code; please read the key directly and define the default in YAML.

🔧 Suggested fix
-        rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg.get("rollout_max_retries_to_avoid_lp_nan", 1),
+        rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg["rollout_max_retries_to_avoid_lp_nan"],

As per coding guidelines, "YAML is the single source of truth for configuration defaults; do not set non-None defaults in code for configuration values" and "Access required config values directly (e.g., policy_cfg['precision']) and assume they are present; do not introduce hidden defaults in code."

nemo_rl/environments/nemo_gym.py (1)

27-31: ⚠️ Potential issue | 🟠 Major

Remove the TypedDict default and mark the field as required.

rollout_max_retries_to_avoid_lp_nan: int = 1 sets a non-None default in code, which violates the guideline that YAML is the single source of truth for configuration defaults. In TypedDict, this pattern is also inconsistent with the codebase convention of using NotRequired[int] for optional fields. Since the field is always explicitly provided at instantiation sites (never relying on the class default), it should be declared as required without a default value.

🔧 Suggested fix
 class NemoGymConfig(TypedDict):
     model_name: str
     base_urls: List[str]
     initial_global_config_dict: Dict[str, Any]
-    rollout_max_retries_to_avoid_lp_nan: int = 1
+    rollout_max_retries_to_avoid_lp_nan: int
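For reference, a small self-contained sketch of the two TypedDict conventions mentioned above, a required field versus NotRequired (assumes Python 3.11+ or typing_extensions for NotRequired):

from typing import TypedDict

try:
    from typing import NotRequired  # Python 3.11+
except ImportError:
    from typing_extensions import NotRequired

class ExampleConfig(TypedDict):
    required_value: int                # must always be provided; no in-code default
    optional_value: NotRequired[int]   # may be omitted entirely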
🤖 Fix all issues with AI agents
In `@nemo_rl/environments/nemo_gym.py`:
- Around lines 115-152: Validate and clearly define the semantics of max_retries
before entering the loop. Ensure cfg["rollout_max_retries_to_avoid_lp_nan"] is
an int >= 1 (or raise a ValueError) so nemo_gym_num_rows is always defined. Keep
the current "max attempts" semantics either by replacing the while trial <
max_retries loop with for attempt in range(max_retries): or by explicitly
documenting that max_retries is the total number of attempts, removing the
off-by-one ambiguity. Reference max_retries, trial, the while loop, and
nemo_gym_num_rows when making the check and adjustment.

In `@tests/unit/environments/test_nemo_gym.py`:
- Around lines 206-208: Add a Google-style docstring to the pytest fixture
nemo_gym_with_patched_run_examples describing its purpose, parameters (if any),
and return value. Place it immediately below the def
nemo_gym_with_patched_run_examples(...) line, use Google-style sections (Args:,
Returns:), and mention that it yields a nemo_gym instance with
RolloutCollectionHelper.run_examples patched for tests, so readers and Sphinx
can parse it.
- Around lines 247-295: The fixture currently calls context.__enter__ and yields
env, but if actor creation or setup fails the function exits before calling
context.__exit__, leaking the patch. Wrap the setup and yield in a try/finally
so context.__exit__ is always called: call context.__enter__ first, then inside
the try create the config and env via NemoGym.options(...).remote, run
ray.get(env.health_check.remote()), and yield env as before; in the finally,
call env.shutdown.remote() and ray.kill(env) only if env was created, then call
context.__exit__(None, None, None) so the patch_run_examples context is reverted
even on failures.
- Around lines 46-47: The mutable module-global run_examples_called should follow
the project's naming and test-safety conventions. Rename it to
G_RUN_EXAMPLES_CALLED (upper snake case with a G_ prefix) and make sure it is
reset before each test to avoid cross-test leakage: update all references to
run_examples_called in this file (e.g., where it is incremented or asserted) to
G_RUN_EXAMPLES_CALLED, and add a pytest fixture or test setup that sets
G_RUN_EXAMPLES_CALLED = 0 before each test runs.
- Around lines 222-239: The patched new_run_examples unpacks the awaitables
returned by orig_run_examples and yields tuples, which makes await task fail in
run_rollouts. Instead, for each awaitable "task" from orig_run_examples(self,
examples, head_server_config), create and yield a new async wrapper coroutine
(or future) that awaits the original task, injects the NaN into
result["response"]["output"] (preserving the has_generation_log_probs check and
raising ValueError if none is found), and then returns the (row, result) pair.
In other words, leave the orig_run_examples and run_rollouts semantics intact by
yielding an awaitable that performs the mutation after awaiting the original
"task" (see the sketch below).
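
A runnable sketch of that wrapper pattern (the nested result layout and key names below are simplified stand-ins, not the actual NeMo-Gym schema):

import asyncio
import math

async def original_task(row):
    # Stand-in for one awaitable produced by the original run_examples generator.
    return row, {"response": {"output": [{"generation_log_probs": [-0.1, -0.2]}]}}

def patched_run_examples(rows):
    # Yield wrapper coroutines instead of awaiting eagerly and yielding plain tuples,
    # so the caller's `await task` contract stays intact.
    for row in rows:
        async def wrapped(row=row):  # bind the current row per iteration
            row_out, result = await original_task(row)
            # Inject a NaN only after the original work has completed.
            result["response"]["output"][0]["generation_log_probs"][0] = math.nan
            return row_out, result
        yield wrapped()

async def main():
    for task in patched_run_examples(range(2)):
        row, result = await task
        print(row, result["response"]["output"][0]["generation_log_probs"])

asyncio.run(main())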

@guyueh1 guyueh1 self-assigned this Feb 5, 2026
@guyueh1 guyueh1 added the CI:L2 Run doctests, unit tests, functional tests, and convergence tests label Feb 9, 2026
Signed-off-by: Guyue Huang <[email protected]>
@guyueh1 guyueh1 added CI:L2 Run doctests, unit tests, functional tests, and convergence tests and removed CI:L2 Run doctests, unit tests, functional tests, and convergence tests labels Feb 9, 2026
Signed-off-by: Guyue Huang <[email protected]>
@guyueh1 guyueh1 added CI:L2 Run doctests, unit tests, functional tests, and convergence tests and removed CI:L2 Run doctests, unit tests, functional tests, and convergence tests labels Feb 9, 2026
@guyueh1 guyueh1 added CI:L2 Run doctests, unit tests, functional tests, and convergence tests and removed CI:L2 Run doctests, unit tests, functional tests, and convergence tests labels Feb 10, 2026

guyueh1 commented Feb 10, 2026

@terrykong the pipeline has passed at commit a3cd107, can you review?

@terrykong terrykong left a comment

generally lgtm. small comments

model_name=policy_generation.cfg["model_name"],
base_urls=policy_generation.dp_openai_server_base_urls,
initial_global_config_dict=config["env"]["nemo_gym"],
rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg.get(
Contributor

couple of things:

  1. can we move the default to the yaml and avoid defaulting here
  2. can we add it to the generation config typeddict with docstring
  3. can we add asserts (==1) that make it clear this has no effect in other code paths (e.g., when nemo-gym isn't used), just so users are not under the impression that other paths are less stable (see the sketch below).
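
A possible shape for the guard in item 3 (illustrative only; the env_cfg name and the placement are assumptions):

    # Hypothetical guard for a non-gym code path: the NaN-retry knob must be unset or 1 there.
    if "rollout_max_attempts_to_avoid_lp_nan" in env_cfg:
        assert env_cfg["rollout_max_attempts_to_avoid_lp_nan"] == 1, (
            "NaN-logprob rollout retries only take effect with the nemo_gym environment"
        )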

Contributor

what about moving rollout_max_retries_to_avoid_lp_nan to the gym env config instead of putting it in the generation config, since it's only used for gym and implemented in the gym env?

Contributor Author

I prefer yuki's idea, and I've now made it env.nemo_gym.rollout_max_attempts_to_avoid_lp_nan

nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
nemo_gym_result, tokenizer
)
max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
Contributor

also maybe

assert self.cfg["rollout_max_attempts_to_avoid_lp_nan"] >= 1, .....

just to give a nice user error instead of skipping this while loop

Contributor

also maybe good to put this assert at init part instead of here.

Contributor Author

added the assert in init
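
For reference, a minimal sketch of such an init-time check (the class structure here is simplified, not the actual NemoGym actor):

class NemoGym:  # simplified sketch only
    def __init__(self, cfg: dict):
        # Fail fast with a clear error instead of silently skipping the retry loop.
        assert cfg["rollout_max_attempts_to_avoid_lp_nan"] >= 1, (
            "rollout_max_attempts_to_avoid_lp_nan must be >= 1"
        )
        self.cfg = cfg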

@yuki-97 yuki-97 left a comment

lgtm and left minor comments.


@guyueh1 guyueh1 added CI:L2 Run doctests, unit tests, functional tests, and convergence tests and removed CI:L2 Run doctests, unit tests, functional tests, and convergence tests labels Feb 12, 2026